/*
@file: mat.h
@author: ZZH
@time: 2022-10-17 16:52:10
@info: 4x4 byte matrix used as the AES state block (4X4矩阵)
*/
#pragma once
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <initializer_list>
#include <iostream>

namespace AES {
    /**
     * @brief 4x4 byte matrix used as the AES state block.
     *
     * Storage is mat[row][column]. The single-index accessors get(pos)/set(pos, v)
     * map pos to column = pos / 4 and row = pos % 4, so consecutive positions walk
     * down a column first.
     *
     * setData/getData, the mix() helpers, transpose and the row-shift steps, and the
     * static substitution/mixing tables are defined out of line elsewhere.
     */
    class Mat_t
    {
    private:
        using uint8_t = std::uint8_t;

        static const uint8_t sBox[256];     // byte-substitution table used by byteReplace()
        static const uint8_t isBox[256];    // inverse table used by ibyteReplace()
        static const uint8_t mixer[4][4];   // coefficients passed to mix() by mixColumn()
        static const uint8_t imixer[4][4];  // coefficients passed to mix() by imixColumn()

        // Matrix contents, indexed mat[row][column]. Zero-initialized so a
        // default-constructed matrix never holds indeterminate bytes.
        uint8_t mat[4][4]{};

        uint8_t mix(uint8_t coef, uint8_t data);   // defined out of line
        void mix(const uint8_t(*pMixer)[4]);       // defined out of line

        // Replaces every byte d of the matrix with pBox[d]; pBox needs 256 entries.
        void replace(const uint8_t* pBox)
        {
            for (auto& row : this->mat)
                for (auto& d : row)
                    d = pBox[d];
        }

    public:
        // All 16 bytes are zero (see the member initializer of `mat`).
        Mat_t() = default;

        // Fills all 16 bytes with `data`.
        // NOTE(review): not explicit, so a uint8_t converts implicitly into a filled
        // matrix (e.g. `m ^= 0x01`); kept that way for source compatibility.
        Mat_t(const uint8_t data) { this->fill(data); }

        // Loads `len` bytes from pMat through setData().
        explicit Mat_t(const uint8_t* pMat, uint8_t len = 16) { this->setData(pMat, len); }

        // Loads the listed bytes through setData(). The element count is clamped to
        // the 16-byte capacity before narrowing to uint8_t, so a list of 256+ elements
        // cannot wrap around to a small length.
        Mat_t(const std::initializer_list<uint8_t>& l)
        {
            const auto count = l.size();
            this->setData(l.begin(), static_cast<uint8_t>(count > 16u ? 16u : count));
        }

        // Copy/move/destruction are the compiler defaults (Rule of Zero): the class
        // only holds a plain byte array.

        void setData(const uint8_t* pData, uint8_t len = 16);       // import from an array (defined out of line)
        void getData(uint8_t* output, uint8_t len = 16) const;      // export to an array (defined out of line)

        // Sets all 16 bytes to `data`.
        void fill(uint8_t data) { std::memset(this->mat, data, sizeof(this->mat)); }

        // Writes mat[y][x] (x = column, y = row); indices above 3 are clamped to 3.
        void set(uint8_t x, uint8_t y, uint8_t value)
        {
            if (x > 3)
                x = 3;

            if (y > 3)
                y = 3;

            this->mat[y][x] = value;
        }

        // Writes the byte at linear position pos (column = pos / 4, row = pos % 4);
        // pos above 15 is clamped to 15.
        void set(uint8_t pos, uint8_t value)
        {
            if (pos > 15)
                pos = 15;

            this->set(static_cast<uint8_t>(pos / 4), static_cast<uint8_t>(pos % 4), value);
        }

        // Reads mat[y][x] (x = column, y = row); indices above 3 are clamped to 3.
        uint8_t get(uint8_t x, uint8_t y) const
        {
            if (x > 3)
                x = 3;

            if (y > 3)
                y = 3;

            return this->mat[y][x];
        }

        // Reads the byte at linear position pos (column = pos / 4, row = pos % 4);
        // pos above 15 is clamped to 15.
        uint8_t get(uint8_t pos) const
        {
            if (pos > 15)
                pos = 15;

            return this->get(static_cast<uint8_t>(pos / 4), static_cast<uint8_t>(pos % 4));
        }

        // XORs `other` into this matrix element-wise and returns *this, matching the
        // built-in compound-assignment convention so it can be chained.
        Mat_t& operator ^= (const Mat_t& other)
        {
            for (int i = 0; i < 4; i++)
                for (int j = 0; j < 4; j++)
                    this->mat[i][j] ^= other.mat[i][j];
            return *this;
        }

        // Element-wise XOR of two matrices.
        friend Mat_t operator ^ (const Mat_t& left, const Mat_t& right)
        {
            Mat_t res(left);
            res ^= right;
            return res;
        }

        // Prints the matrix row by row as "[XX XX XX XX]\r\n" in uppercase hex.
        friend std::ostream& operator << (std::ostream& ost, const Mat_t& m)
        {
            // Each row is 15 characters plus the terminator; snprintf always
            // NUL-terminates when the buffer size is non-zero.
            char lineBuf[20];

            for (const auto& row : m.mat)
            {
                std::snprintf(lineBuf, sizeof(lineBuf), "[%02X %02X %02X %02X]\r\n",
                    row[0], row[1], row[2], row[3]
                );
                ost << lineBuf;
            }
            return ost;
        }

        void transpose(void);   // transpose the current matrix (defined out of line)

        void byteReplace(void) { this->replace(this->sBox); }     // byte substitution through sBox
        void ibyteReplace(void) { this->replace(this->isBox); }   // inverse byte substitution through isBox

        void shiftRows(void);   // row shift (defined out of line)
        void ishiftRows(void);  // inverse row shift (defined out of line)

        void mixColumn(void) { this->mix(this->mixer); }      // column mixing with `mixer`
        void imixColumn(void) { this->mix(this->imixer); }    // inverse column mixing with `imixer`

        // Length of the real (non-padding) data, treating the last byte mat[3][3] as a
        // pad count: returns 16 - mat[3][3]. A pad byte above 16 cannot describe a
        // 16-byte block, so 16 is returned then. Plain 16 - pad would wrap to 17..255
        // and make getUnFilledPart() request more bytes than the matrix holds.
        uint8_t getUnFilledLen(void) const
        {
            const uint8_t pad = this->mat[3][3];
            return pad > 16 ? static_cast<uint8_t>(16) : static_cast<uint8_t>(16 - pad);
        }

        // Exports the first getUnFilledLen() bytes (at most 16) through getData();
        // output needs room for that many bytes.
        void getUnFilledPart(uint8_t* output) const { this->getData(output, this->getUnFilledLen()); }
    };
}
